import argparse
import torch
import rerun as rr
import numpy as np
import pypose as pp
from pathlib import Path

from DataLoader import SequenceBase, StereoFrame, smart_transform
from Evaluation.EvalSeq import EvaluateSequences
from Odometry.MACVO import MACVO

from Utility.Config import load_config, asNamespace
from Utility.PrettyPrint import print_as_table, ColoredTqdm, Logger
from Utility.Sandbox import Sandbox
from Utility.Visualize import fig_plt, rr_plt
from Utility.Timer import Timer


def VisualizeRerunCallback(frame: StereoFrame, system: MACVO, pb: ColoredTqdm):
    """Stream the current odometry state of `system` to Rerun for `frame`.

    Logs (for keyframes only): the estimated trajectory, the left camera pose
    and image, the map points of the latest frame, and the VO tracking points.
    """
    rr.set_time_sequence("frame_idx", frame.frame_idx)

    # Non-key frames are interpolated later and carry nothing worth plotting.
    if system.graph.frames.data["need_interp"][-1]:
        return

    frames = system.graph.frames

    # A trajectory needs at least two poses, so skip the very first frame.
    if frame.frame_idx > 0:
        rr_plt.log_trajectory("/world/est", pp.SE3(frames.data["pose"].tensor))

    rr_plt.log_camera("/world/macvo/cam_left", pp.SE3(frames.data["pose"][-1]), frames.data["K"][-1])
    rr_plt.log_image ("/world/macvo/cam_left", frame.stereo.imageL[0].permute(1, 2, 0))

    # Both point sets are drawn the same way; only the source and entity path differ.
    latest = frames[-1:]
    point_sets = (
        ("/world/point_cloud", system.graph.get_frame2map(latest)),
        ("/world/vo_tracking", system.graph.get_match2point(system.graph.get_frame2match(latest))),
    )
    for entity_path, points in point_sets:
        rr_plt.log_points(entity_path, points.data["pos_Tw"], points.data["color"], points.data["cov_Tw"], "sphere")

    

def VisualizeVRAMUsage(frame: StereoFrame, system: MACVO, pb: ColoredTqdm):
    """Refresh the progress-bar description with graph state and GPU memory use.

    Reports `torch.cuda.memory_reserved` on device 0 (in decimal GB), or "N/A"
    when no CUDA device is available.
    """
    if torch.cuda.is_available():
        reserved_gb = torch.cuda.memory_reserved(0) / 1e9  # bytes -> GB (decimal)
        vram_repr = f"{round(reserved_gb, 3)} GB"
    else:
        vram_repr = "N/A"

    pb.set_description(desc=f"{system.graph}, VRAM={vram_repr}")

def get_args(argv=None):
    """Parse command-line arguments for the MACVO runner.

    Args:
        argv: Optional explicit argument list (e.g. `[]` or `["--noeval"]`).
              Defaults to None, in which case argparse reads `sys.argv[1:]` —
              this preserves the original no-argument call behavior while
              making the parser testable.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--odom", type=str, default = "Config/Experiment/MACVO/MACVO.yaml")
    parser.add_argument("--data", type=str, default = "Config/Sequence/TartanAir_abandonfac_001.yaml")
    parser.add_argument(
        "--seq_to",
        type=int,
        default=None,
        # NOTE: help previously claimed "-1 (default)" — the actual default is None.
        help="Crop sequence to frame# when ran. Leave unset (default) to run the whole sequence",
    )
    parser.add_argument(
        "--seq_from",
        type=int,
        default=0,
        help="Crop sequence from frame# when ran. Set to 0 (default) if wish to start from first frame",
    )
    parser.add_argument(
        "--resultRoot",
        type=str,
        default="./Results",
        help="Directory to store trajectory and files generated by the script."
    )
    parser.add_argument(
        "--useRR",
        action="store_true",
        help="Activate RerunVisualizer to generate <config.Project>.rrd file for visualization.",
    )
    parser.add_argument(
        "--saveplt",
        action="store_true",
        help="Activate PLTVisualizer to generate <frame_idx>.jpg file in space folder for covariance visualization.",
    )
    parser.add_argument(
        "--preload",
        action="store_true",
        help="Preload entire trajectory into RAM to reduce data fetching overhead during runtime."
    )
    parser.add_argument(
        "--autoremove",
        action="store_true",
        help="Cleanup result sandbox after script finishes / crashed. Helpful during testing & debugging."
    )
    parser.add_argument(
        "--noeval", 
        action="store_true",
        # NOTE: help previously said "Evaluate sequence after running odometry.",
        # which is the opposite of what this flag does (it *skips* evaluation).
        help="Skip evaluation of the sequence after running odometry."
    )
    parser.add_argument(
        "--timing",
        action="store_true",
        help="Record timing for system (activate Utility.Timer for global time recording)"
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = get_args()

    # Metadata setup & visualizer setup
    cfg, cfg_dict = load_config(Path(args.odom))
    odomcfg, odomcfg_dict = cfg.Odometry, cfg_dict["Odometry"]
    datacfg, datacfg_dict = load_config(Path(args.data))
    # Run identifier: "<odometry config name>@<sequence config name>"
    project_name = odomcfg.name + "@" + datacfg.name

    # Sandbox = per-run result directory under resultRoot; --autoremove deletes
    # it when the script exits (useful for throwaway debug runs).
    exp_space = Sandbox.create(Path(args.resultRoot), project_name)
    if args.autoremove: exp_space.set_autoremove()
    # Persist the fully-resolved config into the sandbox; it is also what
    # MACVO.from_config reads below, so this write must precede system creation.
    exp_space.config = {
        "Project": project_name,
        "Odometry": odomcfg_dict,
        "Data": {"args": datacfg_dict, "end_idx": args.seq_to, "start_idx": args.seq_from},
    }

    # Setup logging and visualization
    if args.useRR:
        rr_plt.default_mode = "rerun"
        rr_plt.init_connect(project_name)
    
    Timer.setup(active=args.timing)
    # Covariance figures are written as images only when --saveplt is given.
    fig_plt.default_mode = "image" if args.saveplt else "none"

    def onFrameFinished(frame: StereoFrame, system: MACVO, pb: ColoredTqdm):
        # Per-frame hook: stream state to Rerun and show VRAM on the progress bar.
        VisualizeRerunCallback(frame, system, pb)
        VisualizeVRAMUsage(frame, system, pb)

    # Initialize data source, cropped to [seq_from, seq_to) and preprocessed
    # according to cfg.Preprocess.
    sequence = smart_transform(
        SequenceBase[StereoFrame].instantiate(datacfg.type, datacfg.args).clip(args.seq_from, args.seq_to),
        cfg.Preprocess
    )
    
    if args.preload:
        sequence = sequence.preload()  # trade RAM for lower per-frame fetch latency
    
    # Run the odometry over the whole sequence; outputs land in exp_space.
    system = MACVO[StereoFrame].from_config(asNamespace(exp_space.config))
    system.receive_frames(sequence, exp_space, on_frame_finished=onFrameFinished)
    
    # Log the final estimated trajectory from the saved poses.npy.
    # NOTE(review): the [:, 1:] slice presumably drops a leading timestamp/index
    # column — confirm against the poses.npy writer.
    rr_plt.log_trajectory("/world/est"  , torch.tensor(np.load(exp_space.path("poses.npy"))[:, 1:]))
    try:
        # Full map point cloud; raises RuntimeError when mapping was not enabled.
        rr_plt.log_points    ("/world/point_cloud", 
                                system.get_map().map_points.data["pos_Tw"].tensor,
                                system.get_map().map_points.data["color"].tensor,
                                system.get_map().map_points.data["cov_Tw"].tensor,
                                "color")
    except RuntimeError:
        Logger.write("warn", "Unable to log full pointcloud - is mapping mode on?")
    
    # Timing summary (no-op content unless --timing activated the Timer).
    Timer.report()
    Timer.save_elapsed(exp_space.path("elapsed_time.json"))

    # Evaluate the produced trajectory unless --noeval was passed.
    if not args.noeval:
        header, result = EvaluateSequences([str(exp_space.folder)], correct_scale=False)
        print_as_table(header, result)
